# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from lstm_model import LSTM_net
from gcn_model import GCN

# Hybrid Model Section
# Add the Transformer model before the GCN model
class Transformer(nn.Module):
    """Single-layer feature projection computing ``ReLU(Linear(x))``.

    Despite its name this is not an attention-based transformer: it is one
    fully connected layer followed by a ReLU, applied over the last
    dimension of the input tensor.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Registration order (linear, relu, dropout) is kept stable so that
        # parameter initialisation and state_dict keys do not change.
        self.linear = nn.Linear(input_dim, output_dim)
        self.relu = nn.ReLU()
        # NOTE(review): this dropout layer is registered but never applied in
        # forward(), so it has no effect on the output in train or eval mode.
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return ``ReLU(self.linear(x))``; every output element is >= 0."""
        projected = self.linear(x)
        return self.relu(projected)
# Hybrid model (Transformer + GCN + LSTM)
class Hybrid_Network(nn.Module):
    """Two-class classifier that fuses a GCN branch with an LSTM branch.

    Node features are projected by ``Transformer`` (linear + ReLU, same
    width in and out) and fed to the GCN. The GCN output and the LSTM output
    are each reshaped to 32 columns, concatenated into 64 features, and
    mapped by ``self.linear`` to 2 softmax probabilities.

    Args:
        feat_dim: width of the node feature matrix; used as the input and
            output width of the Transformer and as the GCN input width.
        stat_dim: forwarded unchanged to ``GCN``.
        T: forwarded unchanged to ``GCN``.
    """

    def __init__(self, feat_dim, stat_dim, T):
        super(Hybrid_Network, self).__init__()
        self.embed_dim = feat_dim
        self.stat_dim = stat_dim
        self.T = T

        # Feature projection applied to feat_Matrix before the GCN.
        self.transformer = Transformer(feat_dim, feat_dim)
        self.gnn_model = GCN(feat_dim, stat_dim, self.T)
        self.lstm = LSTM_net()

        # NOTE(review): linear1 and tanh are not used in forward(). They are
        # kept so that existing checkpoints (state_dict keys "linear1.*")
        # still load without strict=False.
        self.linear1 = nn.Linear(in_features=64,
                                 out_features=32,
                                 bias=True)
        # Classifier head: 32 GCN features + 32 LSTM features -> 2 classes.
        self.linear = nn.Linear(in_features=64,
                                out_features=2,
                                bias=True)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, feat_Matrix, X_Node, X_Neis, edge_type_index, dg_list, Lstm_feature):
        """Return class probabilities of shape ``(N, 2)``.

        Args:
            feat_Matrix: node feature matrix whose last dimension is
                ``feat_dim`` (presumably ``(num_nodes, feat_dim)`` — TODO
                confirm against the GCN module).
            X_Node, X_Neis, edge_type_index, dg_list: graph inputs forwarded
                unchanged to the GCN.
            Lstm_feature: input forwarded unchanged to the LSTM branch.

        Returns:
            Softmax probabilities over 2 classes, one row per 32-wide row of
            each branch output (N rows).
        """
        # The GCN must consume the projected features. Otherwise the
        # Transformer's output is thrown away and its weights get no gradient.
        transformed_feat = self.transformer(feat_Matrix)
        gnn_result = self.gnn_model(transformed_feat, X_Node, X_Neis, edge_type_index, dg_list)
        lstm_out = self.lstm(Lstm_feature)
        # Each branch is assumed to yield 32 features per sample (the head
        # expects 64 = 32 + 32). reshape() also handles non-contiguous
        # outputs, where view() would raise.
        network_out = torch.cat([gnn_result.reshape(-1, 32), lstm_out.reshape(-1, 32)], dim=1)
        # NOTE(review): the softmax is applied here, so callers get
        # probabilities. If training uses nn.CrossEntropyLoss (which applies
        # log-softmax itself), it should receive logits instead — confirm.
        return self.softmax(self.linear(network_out))




